import torch
from torch import nn
import numpy as np

import torch.nn.functional as F

def network_weight_gaussian_init(net: nn.Module):
    """Re-initialise the parameters of ``net`` in place.

    Visits every module returned by ``net.modules()`` (including ``net``
    itself):

    - ``nn.Conv2d`` / ``nn.Linear``: weight drawn from N(0, 1), bias (if
      present) set to 0.
    - ``nn.BatchNorm2d`` / ``nn.GroupNorm``: affine weight set to 1 and
      affine bias set to 0. Layers built with ``affine=False`` have no such
      parameters (they are ``None``) and are left untouched.
    - Every other module type is ignored.

    Args:
        net: the module to initialise.

    Returns:
        The same ``net`` object, to allow call chaining.
    """
    with torch.no_grad():
        for m in net.modules():
            if isinstance(m, (nn.Conv2d, nn.Linear)):
                nn.init.normal_(m.weight)
                if getattr(m, 'bias', None) is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                # With affine=False the norm layer registers weight/bias as
                # None; initialising them unconditionally would crash.
                if m.weight is not None:
                    nn.init.ones_(m.weight)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)

    return net

def cross_entropy(logit, target):
    """Cross-entropy of ``logit`` against a target distribution.

    ``target`` must already be one-hot (or otherwise a probability
    distribution) and broadcastable against ``logit``; class scores are
    read along dim 1. The per-sample losses (summed over dim 1) are averaged
    over all remaining elements (the batch, for 2-D input) and returned as a
    0-dim tensor.
    """
    log_probs = F.log_softmax(logit, dim=1)
    per_sample = torch.sum(target * log_probs, dim=1)
    return -per_sample.mean()

def gradnorm(train_loader, networks, train_mode=False, num_batch=-1, num_classes=100, verbose=False):
    """Return the global L2 gradient norm of each network on one batch.

    The first ``(inputs, targets)`` batch of ``train_loader`` is taken. For
    every network, ``cross_entropy`` against the one-hot encoded targets is
    back-propagated. The norm ``sqrt(sum_p ||p.grad||^2)`` over all
    parameters that received a gradient is then recorded.

    Args:
        train_loader: iterable yielding ``(inputs, targets)`` batches; only
            the first batch is used. ``targets`` presumably holds integer
            class indices (required by ``F.one_hot``) — confirm with callers.
        networks: iterable of modules whose forward returns a tuple; element
            1 of that tuple is used as the logits.
        train_mode: if True call ``train()`` on each network, else ``eval()``.
        num_batch: currently unused; kept for interface compatibility.
        num_classes: number of classes used for the one-hot encoding.
        verbose: currently unused; kept for interface compatibility.

    Returns:
        list of float: one gradient norm per network, in input order. A
        network whose parameters receive no gradient yields 0.0.

    Raises:
        TypeError: if a network's output is not a tuple.

    Side effects:
        Sets train/eval mode on each network and leaves the computed
        ``.grad`` tensors on its parameters.
    """
    inputs, targets = next(iter(train_loader))
    targets_onehot = F.one_hot(targets, num_classes=num_classes).float()

    # Fallback placement for networks without parameters; networks that have
    # parameters get the batch moved to their own device instead of a
    # hard-coded CUDA device (which crashes on CPU-only hosts).
    if torch.cuda.is_available():
        default_device = torch.device('cuda', torch.cuda.current_device())
    else:
        default_device = torch.device('cpu')

    network_gradnorm = []
    # Single pass: `networks` may be a one-shot iterator.
    for net in networks:
        if train_mode:
            net.train()
        else:
            net.eval()

        first_param = next(net.parameters(), None)
        device = first_param.device if first_param is not None else default_device
        net_inputs = inputs.to(device)
        net_targets = targets_onehot.to(device)

        # Clear stale gradients so only this batch's gradients are measured.
        net.zero_grad()
        output = net(net_inputs)
        if not isinstance(output, tuple):
            raise TypeError(
                f"expected network to return a tuple whose element 1 is the logits, "
                f"got {type(output).__name__}"
            )

        loss = cross_entropy(output[1], net_targets)
        loss.backward()

        norm2_sum = 0
        with torch.no_grad():
            for p in net.parameters():
                if p.grad is not None:
                    norm2_sum += torch.norm(p.grad) ** 2

        # If no parameter received a gradient the accumulator is still the
        # int 0, and torch.sqrt would reject it.
        grad_norm = float(torch.sqrt(norm2_sum)) if torch.is_tensor(norm2_sum) else 0.0
        network_gradnorm.append(grad_norm)

    return network_gradnorm
